# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import os
import six
import subprocess
import time

import numpy as np
import tensorflow as tf
from tensorboard.plugins.beholder import im_util
from tensorboard.util import tb_logging

# Module-level logger obtained from tb_logging; used by the classes below.
logger = tb_logging.get_logger()


class VideoWriter(object):
    """Video file writer that can use different output types.

    Each VideoWriter instance writes video files to a specified
    directory, using the first available VideoOutput from the provided
    list. When the output in use fails with an IOError/OSError, the
    writer falls back to the next output in the list.
    """

    def __init__(self, directory, outputs):
        """Creates a writer targeting `directory`.

        Args:
          directory: Directory handed to each output when it is constructed.
          outputs: Output classes in order of preference. Each must provide
            `available()` and `name()` and be constructible as
            `output(directory, frame_shape)`.

        Raises:
          IOError: If none of `outputs` reports itself as available.
        """
        self.directory = directory
        # Filter to the available outputs
        self.outputs = [out for out in outputs if out.available()]
        if not self.outputs:
            raise IOError("No available video outputs")
        self.output_index = 0
        self.output = None
        self.frame_shape = None

    def current_output(self):
        """Returns the output class at the current fallback position."""
        return self.outputs[self.output_index]

    def _close_output(self):
        """Closes the active output instance, if any, and forgets it."""
        if self.output:
            self.output.close()
        self.output = None

    def write_frame(self, np_array):
        """Writes one frame, falling back across output types on failure.

        Args:
          np_array: The frame to write; its `shape` attribute is compared
            with the previous frame's to decide whether a new video starts.

        Raises:
          IOError: If every remaining output type failed to write the frame.
        """
        # Reset whenever we encounter a new frame shape.
        if self.frame_shape != np_array.shape:
            self._close_output()
            self.frame_shape = np_array.shape
            logger.info("Starting video with frame shape: %s", self.frame_shape)
        # Write the frame, advancing across output types as necessary.
        first_index = self.output_index
        for index in range(first_index, len(self.outputs)):
            # Persist the position so that later frames keep using the
            # output type that last worked (or the last one that was tried).
            self.output_index = index
            try:
                if not self.output:
                    new_output = self.outputs[index]
                    if index > first_index:
                        logger.warning(
                            "Falling back to video output %s", new_output.name()
                        )
                    self.output = new_output(self.directory, self.frame_shape)
                self.output.emit_frame(np_array)
                return
            except (IOError, OSError) as e:
                logger.warning(
                    "Video output type %s not available: %s",
                    self.current_output().name(),
                    str(e),
                )
                self._close_output()
        raise IOError("Exhausted available video outputs")

    def finish(self):
        """Closes the current video and resets the writer's state."""
        self._close_output()
        self.frame_shape = None
        # Reconsider failed outputs when video is manually restarted.
        self.output_index = 0


@six.add_metaclass(abc.ABCMeta)
class VideoOutput(object):
    """Abstract interface implemented by every VideoWriter output type."""

    # Would add @abc.abstractmethod in python 3.3+
    @classmethod
    def available(cls):
        """Reports whether this output type can be used; must be overridden."""
        raise NotImplementedError

    @classmethod
    def name(cls):
        """Returns the class name, used to identify the output type in logs."""
        return cls.__name__

    @abc.abstractmethod
    def emit_frame(self, np_array):
        """Writes a single frame to the output; must be overridden."""
        raise NotImplementedError

    @abc.abstractmethod
    def close(self):
        """Finishes the output; must be overridden."""
        raise NotImplementedError


class PNGVideoOutput(VideoOutput):
    """Video output that stores each frame as a numbered PNG file on disk."""

    @classmethod
    def available(cls):
        """Always returns True; no precondition is checked."""
        return True

    def __init__(self, directory, frame_shape):
        """Creates a timestamped frame directory underneath `directory`.

        Args:
          directory: Parent directory for the new frame directory.
          frame_shape: Ignored by this output type.
        """
        del frame_shape  # unused
        frames_subdir = "/video-frames-{}".format(time.time())
        self.directory = directory + frames_subdir
        self.frame_num = 0
        tf.io.gfile.makedirs(self.directory)

    def emit_frame(self, np_array):
        """Writes `np_array`, cast to uint8, as the next numbered PNG."""
        target = self.directory + "/{:05}.png".format(self.frame_num)
        pixels = np_array.astype(np.uint8)
        im_util.write_image(pixels, target)
        self.frame_num += 1

    def close(self):
        """Does nothing; this output holds no state that needs closing."""


class FFmpegVideoOutput(VideoOutput):
    """Video output implemented by streaming to FFmpeg with .webm output.

    Raw frames are written to the stdin of an `ffmpeg` subprocess, which
    is asked to encode them with the lossless VP9 codec.
    """

    @classmethod
    def available(cls):
        """Returns True if `ffmpeg -version` runs successfully."""
        # Silently check if ffmpeg is available.
        try:
            with open(os.devnull, "wb") as devnull:
                subprocess.check_call(
                    ["ffmpeg", "-version"], stdout=devnull, stderr=devnull
                )
            return True
        except (OSError, subprocess.CalledProcessError):
            return False

    def __init__(self, directory, frame_shape):
        """Starts an FFmpeg process writing a timestamped .webm file.

        Args:
          directory: Directory in which the video file is created.
          frame_shape: (height, width, channels) of the frames to come;
            channels must be 1 (grayscale) or 3 (RGB).

        Raises:
          ValueError: If `frame_shape` is not rank 3 or has an unsupported
            channel count.
        """
        self.filename = directory + "/video-{}.webm".format(time.time())
        if len(frame_shape) != 3:
            raise ValueError(
                "Expected rank-3 array for frame, got %s" % str(frame_shape)
            )
        # Set input pixel format based on channel count.
        if frame_shape[2] == 1:
            pix_fmt = "gray"
        elif frame_shape[2] == 3:
            pix_fmt = "rgb24"
        else:
            raise ValueError("Unsupported channel count %d" % frame_shape[2])

        command = [
            "ffmpeg",
            "-y",  # Overwrite output
            # Input options - raw video file format and codec.
            "-f",
            "rawvideo",
            "-vcodec",
            "rawvideo",
            "-s",
            "%dx%d" % (frame_shape[1], frame_shape[0]),  # Width x height.
            "-pix_fmt",
            pix_fmt,
            "-r",
            "15",  # Frame rate: arbitrarily use 15 frames per second.
            "-i",
            "-",  # Use stdin.
            "-an",  # No audio.
            # Output options - use lossless VP9 codec inside .webm.
            "-vcodec",
            "libvpx-vp9",
            "-lossless",
            "1",
            # Using YUV is most compatible, though conversion from RGB skews colors.
            "-pix_fmt",
            "yuv420p",
            self.filename,
        ]
        PIPE = subprocess.PIPE
        # NOTE(review): stdout/stderr are piped but only read in
        # communicate(); a long stream could presumably fill the pipe buffer
        # and stall FFmpeg - confirm whether FFmpeg logging should be reduced.
        self.ffmpeg = subprocess.Popen(
            command, stdin=PIPE, stdout=PIPE, stderr=PIPE
        )

    def _handle_error(self):
        """Waits for FFmpeg to exit and logs whatever it wrote to stderr."""
        _, stderr = self.ffmpeg.communicate()
        # The pipes are binary, so communicate() yields bytes; decode before
        # stripping, since bytes.rstrip("\n") raises TypeError on Python 3.
        stderr = stderr or b""
        if isinstance(stderr, bytes):
            stderr = stderr.decode("utf-8", "replace")
        bar = "=" * 40
        logger.error(
            "Error writing to FFmpeg:\n%s\n%s\n%s",
            bar,
            stderr.rstrip("\n"),
            bar,
        )

    def emit_frame(self, np_array):
        """Sends the raw bytes of `np_array` to FFmpeg's stdin.

        Raises:
          IOError: If writing to the FFmpeg process fails; FFmpeg's stderr
            is logged first.
        """
        try:
            self.ffmpeg.stdin.write(np_array.tobytes())
            self.ffmpeg.stdin.flush()
        except IOError:
            self._handle_error()
            raise IOError("Failure invoking FFmpeg")

    def close(self):
        """Lets FFmpeg finish and releases the process handle.

        Safe to call more than once; later calls do nothing.
        """
        if self.ffmpeg is None:
            return
        if self.ffmpeg.poll() is None:
            # Close stdin and consume and discard stderr/stdout.
            self.ffmpeg.communicate()
        self.ffmpeg = None
